from ActualCausal.Inference.General.counterfactual import compute_splitting_losses
from ActualCausal.Train.train_utils import compute_expectile
from ActualCausal.Train.Regularization.embedding_reg import compute_embedding_losses
from ActualCausal.Inference.General.null import compute_null_consistency_losses
from ActualCausal.Inference.General.attention import compute_attention_loss

def apply_regularizers(base_loss, args, params, model, batch, results, skip_names=()):
    """Reduce a per-element loss and add every enabled regularization term.

    The base loss is reduced with ``compute_expectile`` when
    ``args.train.loss_type == "expectile"`` and "expectile" is not in
    ``skip_names``; otherwise it is reduced with a plain ``.mean()``.
    Each regularizer is then added to the reduced loss when its lambda under
    ``args.inter.regularization`` is ``>= 0`` and its name is not skipped.

    Args:
        base_loss: Unreduced loss; must support ``.mean()`` (presumably a
            tensor -- TODO confirm).
        args: Config namespace; reads ``args.train`` and
            ``args.inter.regularization``.
        params: Runtime parameters; ``params.expt_threshold`` is read for the
            expectile reduction and ``params`` is forwarded to the helpers.
        model: Forwarded to the regularization helpers.
        batch: Forwarded to the regularization helpers.
        results: Forwarded to the embedding, null-consistency and attention
            helpers.
        skip_names: Collection of term names to skip. Names checked here:
            "expectile", "splitting", "embedding" (skips both embedding
            terms), "null_consist", "attention".

    Returns:
        The reduced base loss plus all enabled regularization terms.
    """
    if args.train.loss_type == "expectile" and "expectile" not in skip_names:
        expectile_cfg = args.train.expectile
        loss = compute_expectile(
            base_loss,
            ord=expectile_cfg.expt_ord,
            threshold=params.expt_threshold,
            # Previously read ``args.trian`` (typo), which raised AttributeError.
            expt_tau=expectile_cfg.expt_tau,
        )
    else:
        # Default to an unadulterated mean loss (e.g. loss_type == "mean").
        loss = base_loss.mean()

    # TODO: add more regularizers, make sure skipping is properly implemented in train_passive
    # TODO: return the regularization components for debugging and adaptive rates
    # NOTE(review): a lambda of exactly 0 still enables its term, and the terms
    # are not scaled by their lambda here -- presumably the helpers apply the
    # weight internally; TODO confirm.
    reg = args.inter.regularization
    if reg.splitting.splitting_lambda[0] >= 0 and "splitting" not in skip_names:
        loss += compute_splitting_losses(args, params, model, batch).split_loss.mean()
    if reg.embedding.embed_reg_lambda >= 0 and "embedding" not in skip_names:
        loss += compute_embedding_losses(args, params, model, batch, results).mean()
    if reg.embedding.mask_embed_reg_lambda >= 0 and "embedding" not in skip_names:
        loss += compute_embedding_losses(
            args, params, model, batch, results, use_masked=True
        ).mean()
    if reg.null_consistency.null_reg_lambda >= 0 and "null_consist" not in skip_names:
        loss += compute_null_consistency_losses(args, params, model, batch, results).mean()
    if reg.attention.attn_reg_lambda >= 0 and "attention" not in skip_names:
        # Added as returned (no .mean()); presumably already a scalar -- TODO confirm.
        loss += compute_attention_loss(model, batch, args, params, results, keep_all=False)
    return loss